{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ben/miniconda3/envs/main/lib/python3.9/site-packages/thinc/compat.py:36: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'\n",
      "  hasattr(torch, \"has_mps\")\n",
      "/home/ben/miniconda3/envs/main/lib/python3.9/site-packages/thinc/compat.py:37: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'\n",
      "  and torch.has_mps  # type: ignore[attr-defined]\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import random\n",
    "import numpy as np\n",
    "import spacy\n",
    "import datasets\n",
    "import torchtext\n",
    "import tqdm\n",
    "import evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 1234\n",
    "\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed(seed)\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = datasets.load_dataset(\"bentrevett/multi30k\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data, valid_data, test_data = (\n",
    "    dataset[\"train\"],\n",
    "    dataset[\"validation\"],\n",
    "    dataset[\"test\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ben/miniconda3/envs/main/lib/python3.9/site-packages/spacy/util.py:910: UserWarning: [W095] Model 'en_core_web_sm' (3.5.0) was trained with spaCy v3.5.0 and may not be 100% compatible with the current version (3.7.2). If you see errors or degraded performance, download a newer compatible model or retrain your custom model with the current spaCy version. For more details and available updates, run: python -m spacy validate\n",
      "  warnings.warn(warn_msg)\n",
      "/home/ben/miniconda3/envs/main/lib/python3.9/site-packages/spacy/util.py:910: UserWarning: [W095] Model 'de_core_news_sm' (3.5.0) was trained with spaCy v3.5.0 and may not be 100% compatible with the current version (3.7.2). If you see errors or degraded performance, download a newer compatible model or retrain your custom model with the current spaCy version. For more details and available updates, run: python -m spacy validate\n",
      "  warnings.warn(warn_msg)\n"
     ]
    }
   ],
   "source": [
    "en_nlp = spacy.load(\"en_core_web_sm\")\n",
    "de_nlp = spacy.load(\"de_core_news_sm\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_example(example, en_nlp, de_nlp, max_length, lower, sos_token, eos_token):\n",
    "    en_tokens = [token.text for token in en_nlp.tokenizer(example[\"en\"])][:max_length]\n",
    "    de_tokens = [token.text for token in de_nlp.tokenizer(example[\"de\"])][:max_length]\n",
    "    if lower:\n",
    "        en_tokens = [token.lower() for token in en_tokens]\n",
    "        de_tokens = [token.lower() for token in de_tokens]\n",
    "    en_tokens = [sos_token] + en_tokens + [eos_token]\n",
    "    de_tokens = [sos_token] + de_tokens + [eos_token]\n",
    "    return {\"en_tokens\": en_tokens, \"de_tokens\": de_tokens}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "811b88c6d1d548ceb1f725a8bab9d886",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/29000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a00ad980f545429fb232773c4db6948d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/1014 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "93be8b6475084a8c80c4f8e3902e624d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/1000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "max_length = 1_000\n",
    "lower = True\n",
    "sos_token = \"<sos>\"\n",
    "eos_token = \"<eos>\"\n",
    "\n",
    "fn_kwargs = {\n",
    "    \"en_nlp\": en_nlp,\n",
    "    \"de_nlp\": de_nlp,\n",
    "    \"max_length\": max_length,\n",
    "    \"lower\": lower,\n",
    "    \"sos_token\": sos_token,\n",
    "    \"eos_token\": eos_token,\n",
    "}\n",
    "\n",
    "train_data = train_data.map(tokenize_example, fn_kwargs=fn_kwargs)\n",
    "valid_data = valid_data.map(tokenize_example, fn_kwargs=fn_kwargs)\n",
    "test_data = test_data.map(tokenize_example, fn_kwargs=fn_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "min_freq = 2\n",
    "unk_token = \"<unk>\"\n",
    "pad_token = \"<pad>\"\n",
    "\n",
    "special_tokens = [\n",
    "    unk_token,\n",
    "    pad_token,\n",
    "    sos_token,\n",
    "    eos_token,\n",
    "]\n",
    "\n",
    "en_vocab = torchtext.vocab.build_vocab_from_iterator(\n",
    "    train_data[\"en_tokens\"],\n",
    "    min_freq=min_freq,\n",
    "    specials=special_tokens,\n",
    ")\n",
    "\n",
    "de_vocab = torchtext.vocab.build_vocab_from_iterator(\n",
    "    train_data[\"de_tokens\"],\n",
    "    min_freq=min_freq,\n",
    "    specials=special_tokens,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert en_vocab[unk_token] == de_vocab[unk_token]\n",
    "assert en_vocab[pad_token] == de_vocab[pad_token]\n",
    "\n",
    "unk_index = en_vocab[unk_token]\n",
    "pad_index = en_vocab[pad_token]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "en_vocab.set_default_index(unk_index)\n",
    "de_vocab.set_default_index(unk_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def numericalize_example(example, en_vocab, de_vocab):\n",
    "    en_ids = en_vocab.lookup_indices(example[\"en_tokens\"])\n",
    "    de_ids = de_vocab.lookup_indices(example[\"de_tokens\"])\n",
    "    return {\"en_ids\": en_ids, \"de_ids\": de_ids}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dde45893f8fa4826890cd1cac51b1512",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/29000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eee99d0341454c669b570f0c6588c432",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/1014 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b0f6f712581a4ea6a4d35ce1dbb57527",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/1000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fn_kwargs = {\"en_vocab\": en_vocab, \"de_vocab\": de_vocab}\n",
    "\n",
    "train_data = train_data.map(numericalize_example, fn_kwargs=fn_kwargs)\n",
    "valid_data = valid_data.map(numericalize_example, fn_kwargs=fn_kwargs)\n",
    "test_data = test_data.map(numericalize_example, fn_kwargs=fn_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_type = \"torch\"\n",
    "format_columns = [\"en_ids\", \"de_ids\"]\n",
    "\n",
    "train_data = train_data.with_format(\n",
    "    type=data_type, columns=format_columns, output_all_columns=True\n",
    ")\n",
    "\n",
    "valid_data = valid_data.with_format(\n",
    "    type=data_type,\n",
    "    columns=format_columns,\n",
    "    output_all_columns=True,\n",
    ")\n",
    "\n",
    "test_data = test_data.with_format(\n",
    "    type=data_type,\n",
    "    columns=format_columns,\n",
    "    output_all_columns=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_collate_fn(pad_index):\n",
    "    def collate_fn(batch):\n",
    "        batch_en_ids = [example[\"en_ids\"] for example in batch]\n",
    "        batch_de_ids = [example[\"de_ids\"] for example in batch]\n",
    "        batch_en_ids = nn.utils.rnn.pad_sequence(batch_en_ids, padding_value=pad_index)\n",
    "        batch_de_ids = nn.utils.rnn.pad_sequence(batch_de_ids, padding_value=pad_index)\n",
    "        batch = {\n",
    "            \"en_ids\": batch_en_ids,\n",
    "            \"de_ids\": batch_de_ids,\n",
    "        }\n",
    "        return batch\n",
    "\n",
    "    return collate_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data_loader(dataset, batch_size, pad_index, shuffle=False):\n",
    "    collate_fn = get_collate_fn(pad_index)\n",
    "    data_loader = torch.utils.data.DataLoader(\n",
    "        dataset=dataset,\n",
    "        batch_size=batch_size,\n",
    "        collate_fn=collate_fn,\n",
    "        shuffle=shuffle,\n",
    "    )\n",
    "\n",
    "    return data_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "\n",
    "train_data_loader = get_data_loader(train_data, batch_size, pad_index, shuffle=True)\n",
    "valid_data_loader = get_data_loader(valid_data, batch_size, pad_index)\n",
    "test_data_loader = get_data_loader(test_data, batch_size, pad_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    def __init__(self, input_dim, embedding_dim, hidden_dim, dropout):\n",
    "        super().__init__()\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.embedding = nn.Embedding(input_dim, embedding_dim)\n",
    "        self.rnn = nn.GRU(embedding_dim, hidden_dim)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, src):\n",
    "        # src = [src length, batch size]\n",
    "        embedded = self.dropout(self.embedding(src))\n",
    "        # embedded = [src length, batch size, embedding dim]\n",
    "        outputs, hidden = self.rnn(embedded)  # no cell state in GRU!\n",
    "        # outputs = [src length, batch size, hidden dim * n directions]\n",
    "        # hidden = [n layers * n directions, batch size, hidden dim]\n",
    "        # outputs are always from the top hidden layer\n",
    "        return hidden"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    def __init__(self, output_dim, embedding_dim, hidden_dim, dropout):\n",
    "        super().__init__()\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.output_dim = output_dim\n",
    "        self.embedding = nn.Embedding(output_dim, embedding_dim)\n",
    "        self.rnn = nn.GRU(embedding_dim + hidden_dim, hidden_dim)\n",
    "        self.fc_out = nn.Linear(embedding_dim + hidden_dim * 2, output_dim)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, input, hidden, context):\n",
    "        # input = [batch size]\n",
    "        # hidden = [n layers * n directions, batch size, hidden dim]\n",
    "        # context = [n layers * n directions, batch size, hidden dim]\n",
    "        # n layers and n directions in the decoder will both always be 1, therefore:\n",
    "        # hidden = [1, batch size, hidden dim]\n",
    "        # context = [1, batch size, hidden dim]\n",
    "        input = input.unsqueeze(0)\n",
    "        # input = [1, batch size]\n",
    "        embedded = self.dropout(self.embedding(input))\n",
    "        # embedded = [1, batch size, embedding dim]\n",
    "        emb_con = torch.cat((embedded, context), dim=2)\n",
    "        # emb_con = [1, batch size, embedding dim + hidden dim]\n",
    "        output, hidden = self.rnn(emb_con, hidden)\n",
    "        # output = [seq len, batch size, hidden dim * n directions]\n",
    "        # hidden = [n layers * n directions, batch size, hidden dim]\n",
    "        # seq len, n layers and n directions will always be 1 in this decoder, therefore:\n",
    "        # output = [1, batch size, hidden dim]\n",
    "        # hidden = [1, batch size, hidden dim]\n",
    "        output = torch.cat(\n",
    "            (embedded.squeeze(0), hidden.squeeze(0), context.squeeze(0)), dim=1\n",
    "        )\n",
    "        # output = [batch size, embedding dim + hidden dim * 2]\n",
    "        prediction = self.fc_out(output)\n",
    "        # prediction = [batch size, output dim]\n",
    "        return prediction, hidden"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2Seq(nn.Module):\n",
    "    def __init__(self, encoder, decoder, device):\n",
    "        super().__init__()\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "        self.device = device\n",
    "        assert (\n",
    "            encoder.hidden_dim == decoder.hidden_dim\n",
    "        ), \"Hidden dimensions of encoder and decoder must be equal!\"\n",
    "\n",
    "    def forward(self, src, trg, teacher_forcing_ratio):\n",
    "        # src = [src length, batch size]\n",
    "        # trg = [trg length, batch size]\n",
    "        # teacher_forcing_ratio is probability to use teacher forcing\n",
    "        # e.g. if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time\n",
    "        batch_size = trg.shape[1]\n",
    "        trg_length = trg.shape[0]\n",
    "        trg_vocab_size = self.decoder.output_dim\n",
    "        # tensor to store decoder outputs\n",
    "        outputs = torch.zeros(trg_length, batch_size, trg_vocab_size).to(self.device)\n",
    "        # last hidden state of the encoder is the context\n",
    "        context = self.encoder(src)\n",
    "        # context = [n layers * n directions, batch size, hidden dim]\n",
    "        # context also used as the initial hidden state of the decoder\n",
    "        hidden = context\n",
    "        # hidden = [n layers * n directions, batch size, hidden dim]\n",
    "        # first input to the decoder is the <sos> tokens\n",
    "        input = trg[0, :]\n",
    "        for t in range(1, trg_length):\n",
    "            # insert input token embedding, previous hidden state and the context state\n",
    "            # receive output tensor (predictions) and new hidden state\n",
    "            output, hidden = self.decoder(input, hidden, context)\n",
    "            # output = [batch size, output dim]\n",
    "            # hidden = [1, batch size, hidden dim]\n",
    "            # place predictions in a tensor holding predictions for each token\n",
    "            outputs[t] = output\n",
    "            # decide if we are going to use teacher forcing or not\n",
    "            teacher_force = random.random() < teacher_forcing_ratio\n",
    "            # get the highest predicted token from our predictions\n",
    "            top1 = output.argmax(1)\n",
    "            # if teacher forcing, use actual next token as next input\n",
    "            # if not, use predicted token\n",
    "            input = trg[t] if teacher_force else top1\n",
    "            # input = [batch size]\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_dim = len(de_vocab)\n",
    "output_dim = len(en_vocab)\n",
    "encoder_embedding_dim = 256\n",
    "decoder_embedding_dim = 256\n",
    "hidden_dim = 512\n",
    "encoder_dropout = 0.5\n",
    "decoder_dropout = 0.5\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "encoder = Encoder(\n",
    "    input_dim,\n",
    "    encoder_embedding_dim,\n",
    "    hidden_dim,\n",
    "    encoder_dropout,\n",
    ")\n",
    "\n",
    "decoder = Decoder(\n",
    "    output_dim,\n",
    "    decoder_embedding_dim,\n",
    "    hidden_dim,\n",
    "    decoder_dropout,\n",
    ")\n",
    "\n",
    "model = Seq2Seq(encoder, decoder, device).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Seq2Seq(\n",
       "  (encoder): Encoder(\n",
       "    (embedding): Embedding(7853, 256)\n",
       "    (rnn): GRU(256, 512)\n",
       "    (dropout): Dropout(p=0.5, inplace=False)\n",
       "  )\n",
       "  (decoder): Decoder(\n",
       "    (embedding): Embedding(5893, 256)\n",
       "    (rnn): GRU(768, 512)\n",
       "    (fc_out): Linear(in_features=1280, out_features=5893, bias=True)\n",
       "    (dropout): Dropout(p=0.5, inplace=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def init_weights(m):\n",
    "    for name, param in m.named_parameters():\n",
    "        nn.init.normal_(param.data, mean=0, std=0.01)\n",
    "\n",
    "\n",
    "model.apply(init_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model has 14,219,781 trainable parameters\n"
     ]
    }
   ],
   "source": [
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "\n",
    "print(f\"The model has {count_parameters(model):,} trainable parameters\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss(ignore_index=pad_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_fn(\n",
    "    model, data_loader, optimizer, criterion, clip, teacher_forcing_ratio, device\n",
    "):\n",
    "    model.train()\n",
    "    epoch_loss = 0\n",
    "    for i, batch in enumerate(data_loader):\n",
    "        src = batch[\"de_ids\"].to(device)\n",
    "        trg = batch[\"en_ids\"].to(device)\n",
    "        # src = [src length, batch size]\n",
    "        # trg = [trg length, batch size]\n",
    "        optimizer.zero_grad()\n",
    "        output = model(src, trg, teacher_forcing_ratio)\n",
    "        # output = [trg length, batch size, trg vocab size]\n",
    "        output_dim = output.shape[-1]\n",
    "        output = output[1:].view(-1, output_dim)\n",
    "        # output = [(trg length - 1) * batch size, trg vocab size]\n",
    "        trg = trg[1:].view(-1)\n",
    "        # trg = [(trg length - 1) * batch size]\n",
    "        loss = criterion(output, trg)\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n",
    "        optimizer.step()\n",
    "        epoch_loss += loss.item()\n",
    "    return epoch_loss / len(data_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_fn(model, data_loader, criterion, device):\n",
    "    model.eval()\n",
    "    epoch_loss = 0\n",
    "    with torch.no_grad():\n",
    "        for i, batch in enumerate(data_loader):\n",
    "            src = batch[\"de_ids\"].to(device)\n",
    "            trg = batch[\"en_ids\"].to(device)\n",
    "            # src = [src length, batch size]\n",
    "            # trg = [trg length, batch size]\n",
    "            output = model(src, trg, 0)  # turn off teacher forcing\n",
    "            # output = [trg length, batch size, trg vocab size]\n",
    "            output_dim = output.shape[-1]\n",
    "            output = output[1:].view(-1, output_dim)\n",
    "            # output = [(trg length - 1) * batch size, trg vocab size]\n",
    "            trg = trg[1:].view(-1)\n",
    "            # trg = [(trg length - 1) * batch size]\n",
    "            loss = criterion(output, trg)\n",
    "            epoch_loss += loss.item()\n",
    "    return epoch_loss / len(data_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|██████████████▋                                                                                                                                    | 1/10 [00:16<02:29, 16.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss:   5.065 | Train PPL: 158.369\n",
      "\tValid Loss:   5.033 | Valid PPL: 153.368\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|█████████████████████████████▍                                                                                                                     | 2/10 [00:32<02:10, 16.36s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss:   4.300 | Train PPL:  73.676\n",
      "\tValid Loss:   4.679 | Valid PPL: 107.619\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|████████████████████████████████████████████                                                                                                       | 3/10 [00:48<01:53, 16.20s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss:   3.868 | Train PPL:  47.865\n",
      "\tValid Loss:   4.343 | Valid PPL:  76.914\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|██████████████████████████████████████████████████████████▊                                                                                        | 4/10 [01:04<01:36, 16.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss:   3.469 | Train PPL:  32.116\n",
      "\tValid Loss:   4.010 | Valid PPL:  55.130\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████████████████████████████████████████████████████████████████████████▌                                                                         | 5/10 [01:21<01:20, 16.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss:   3.111 | Train PPL:  22.434\n",
      "\tValid Loss:   3.829 | Valid PPL:  46.027\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|████████████████████████████████████████████████████████████████████████████████████████▏                                                          | 6/10 [01:37<01:04, 16.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss:   2.842 | Train PPL:  17.156\n",
      "\tValid Loss:   3.716 | Valid PPL:  41.116\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|██████████████████████████████████████████████████████████████████████████████████████████████████████▉                                            | 7/10 [01:53<00:48, 16.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss:   2.608 | Train PPL:  13.576\n",
      "\tValid Loss:   3.663 | Valid PPL:  38.986\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                             | 8/10 [02:09<00:32, 16.17s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss:   2.367 | Train PPL:  10.661\n",
      "\tValid Loss:   3.672 | Valid PPL:  39.319\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎              | 9/10 [02:25<00:16, 16.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss:   2.178 | Train PPL:   8.827\n",
      "\tValid Loss:   3.642 | Valid PPL:  38.160\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:41<00:00, 16.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss:   1.988 | Train PPL:   7.299\n",
      "\tValid Loss:   3.725 | Valid PPL:  41.456\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "n_epochs = 10\n",
    "clip = 1.0\n",
    "teacher_forcing_ratio = 0.5\n",
    "\n",
    "best_valid_loss = float(\"inf\")\n",
    "\n",
    "for epoch in tqdm.tqdm(range(n_epochs)):\n",
    "    train_loss = train_fn(\n",
    "        model,\n",
    "        train_data_loader,\n",
    "        optimizer,\n",
    "        criterion,\n",
    "        clip,\n",
    "        teacher_forcing_ratio,\n",
    "        device,\n",
    "    )\n",
    "    valid_loss = evaluate_fn(\n",
    "        model,\n",
    "        valid_data_loader,\n",
    "        criterion,\n",
    "        device,\n",
    "    )\n",
    "    if valid_loss < best_valid_loss:\n",
    "        best_valid_loss = valid_loss\n",
    "        torch.save(model.state_dict(), \"tut2-model.pt\")\n",
    "    print(f\"\\tTrain Loss: {train_loss:7.3f} | Train PPL: {np.exp(train_loss):7.3f}\")\n",
    "    print(f\"\\tValid Loss: {valid_loss:7.3f} | Valid PPL: {np.exp(valid_loss):7.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Test Loss: 3.623 | Test PPL:  37.460 |\n"
     ]
    }
   ],
   "source": [
    "model.load_state_dict(torch.load(\"tut2-model.pt\"))\n",
    "\n",
    "test_loss = evaluate_fn(model, test_data_loader, criterion, device)\n",
    "\n",
    "print(f\"| Test Loss: {test_loss:.3f} | Test PPL: {np.exp(test_loss):7.3f} |\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def translate_sentence(\n",
    "    sentence,\n",
    "    model,\n",
    "    en_nlp,\n",
    "    de_nlp,\n",
    "    en_vocab,\n",
    "    de_vocab,\n",
    "    lower,\n",
    "    sos_token,\n",
    "    eos_token,\n",
    "    device,\n",
    "    max_output_length=25,\n",
    "):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        if isinstance(sentence, str):\n",
    "            tokens = [token.text for token in de_nlp.tokenizer(sentence)]\n",
    "        else:\n",
    "            tokens = [token for token in sentence]\n",
    "        if lower:\n",
    "            tokens = [token.lower() for token in tokens]\n",
    "        tokens = [sos_token] + tokens + [eos_token]\n",
    "        ids = de_vocab.lookup_indices(tokens)\n",
    "        tensor = torch.LongTensor(ids).unsqueeze(-1).to(device)\n",
    "        context = model.encoder(tensor)\n",
    "        hidden = context\n",
    "        inputs = en_vocab.lookup_indices([sos_token])\n",
    "        for _ in range(max_output_length):\n",
    "            inputs_tensor = torch.LongTensor([inputs[-1]]).to(device)\n",
    "            output, hidden = model.decoder(inputs_tensor, hidden, context)\n",
    "            predicted_token = output.argmax(-1).item()\n",
    "            inputs.append(predicted_token)\n",
    "            if predicted_token == en_vocab[eos_token]:\n",
    "                break\n",
    "        tokens = en_vocab.lookup_tokens(inputs)\n",
    "    return tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt.',\n",
       " 'A man in an orange hat starring at something.')"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentence = test_data[0][\"de\"]\n",
    "expected_translation = test_data[0][\"en\"]\n",
    "\n",
    "sentence, expected_translation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "translation = translate_sentence(\n",
    "    sentence,\n",
    "    model,\n",
    "    en_nlp,\n",
    "    de_nlp,\n",
    "    en_vocab,\n",
    "    de_vocab,\n",
    "    lower,\n",
    "    sos_token,\n",
    "    eos_token,\n",
    "    device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['<sos>',\n",
       " 'a',\n",
       " 'man',\n",
       " 'in',\n",
       " 'a',\n",
       " 'white',\n",
       " 'hat',\n",
       " 'is',\n",
       " 'something',\n",
       " 'something',\n",
       " '.',\n",
       " '<eos>']"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "translation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentence = \"Ein Mann sieht sich einen Film an.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "translation = translate_sentence(\n",
    "    sentence,\n",
    "    model,\n",
    "    en_nlp,\n",
    "    de_nlp,\n",
    "    en_vocab,\n",
    "    de_vocab,\n",
    "    lower,\n",
    "    sos_token,\n",
    "    eos_token,\n",
    "    device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['<sos>', 'a', 'man', 'looking', 'at', 'a', 'grill', '.', '<eos>']"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "translation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 212.75it/s]\n"
     ]
    }
   ],
   "source": [
    "translations = [\n",
    "    translate_sentence(\n",
    "        example[\"de\"],\n",
    "        model,\n",
    "        en_nlp,\n",
    "        de_nlp,\n",
    "        en_vocab,\n",
    "        de_vocab,\n",
    "        lower,\n",
    "        sos_token,\n",
    "        eos_token,\n",
    "        device,\n",
    "    )\n",
    "    for example in tqdm.tqdm(test_data)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "bleu = evaluate.load(\"bleu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions = [\" \".join(translation[1:-1]) for translation in translations]\n",
    "\n",
    "references = [[example[\"en\"]] for example in test_data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_tokenizer_fn(nlp, lower):\n",
    "    def tokenizer_fn(s):\n",
    "        tokens = [token.text for token in nlp.tokenizer(s)]\n",
    "        if lower:\n",
    "            tokens = [token.lower() for token in tokens]\n",
    "        return tokens\n",
    "\n",
    "    return tokenizer_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer_fn = get_tokenizer_fn(en_nlp, lower)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = bleu.compute(\n",
    "    predictions=predictions, references=references, tokenizer=tokenizer_fn\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'bleu': 0.20611261550518511,\n",
       " 'precisions': [0.5521618710570857,\n",
       "  0.2777962993832305,\n",
       "  0.14720858337879614,\n",
       "  0.08141628325665133],\n",
       " 'brevity_penalty': 0.9953945430070688,\n",
       " 'length_ratio': 0.9954051156379231,\n",
       " 'translation_length': 12998,\n",
       " 'reference_length': 13058}"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
